# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
python main_watermark.py --model_name guanaco-7b \
    --prompt_type guanaco --prompt_path data.json --nsamples 10 --batch_size 16 \
    --method openai --method_detect openai --seeding hash --ngram 2 --scoring_method v2 --temperature 1.0 \
    --payload 0 --payload_max 4 \
    --output_dir output/

python main_watermark.py --model_name guanaco-7b \
    --prompt_type guanaco --prompt_path data.json --nsamples 10 --batch_size 16 \
    --method maryland --seeding hash --ngram 2 --gamma 0.25 --delta 2.0 \
    --payload 0 --payload_max 4 \
    --output_dir output/
"""

import argparse
import random
#import multiprocessing
#multiprocessing.set_start_method('spawn', force=True)
import string
import pickle
from typing import Dict, List
import os
import time
import json

import tqdm
import pandas as pd
import numpy as np

import torch
torch.multiprocessing.set_start_method('spawn', force=True)
from peft import PeftModel    
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

from wm import (WmGenerator, OpenaiGenerator, OpenaiDetector, OpenaiNeymanPearsonDetector, 
                OpenaiDetectorZ, MarylandGenerator, MarylandDetector, MarylandDetectorZ)
import utils

import hyperparse
from wm.utools import *
import traceback

# Experiment switches parsed by the project's `hyperparse` helper. Each call
# yields a dict-like mapping of switches plus a string form of it; the string
# form of usermode is used by main() as an output sub-directory name.
# NOTE(review): where these values come from (CLI / environment) is defined in
# hyperparse, not here — confirm there.
usermode, usermode_str = hyperparse.parse("usermode")
extramode, extramode_str = hyperparse.parse("extramode")
# Fold extramode into usermode so most checks only consult usermode
# (assuming dict semantics, extramode values win on key collisions).
# extramode is still checked on its own in main() for "detect"/"eval"/"dkey".
usermode.update(extramode)


def get_args_parser():
    """Build the command-line parser for watermark generation and detection.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not
    registered on it.

    Returns:
        argparse.ArgumentParser: parser holding the model, prompt, generation,
        watermark, multibit, experiment and distributed options.
    """
    parser = argparse.ArgumentParser('Args', add_help=False)

    # model parameters
    parser.add_argument('--model_name', type=str)

    # prompts parameters
    parser.add_argument('--prompt_path', type=str, default="data/alpaca_data.json")
    # Only the templates implemented in format_prompts are advertised here.
    parser.add_argument('--prompt_type', type=str, default="alpaca",
                        help='type of prompt formatting. Choose between: alpaca, guanaco')
    parser.add_argument('--prompt', type=str, nargs='+', default=None,
                        help='prompt to use instead of prompt_path, can be a list')

    # generation parameters
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--max_gen_len', type=int, default=256)

    # watermark parameters
    parser.add_argument('--method', type=str, default='none',
                        help='Choose between: none (no watermarking), openai (Aaronson et al.), maryland (Kirchenbauer et al.)')
    parser.add_argument('--method_detect', type=str, default='same',
                        help='Statistical test to detect watermark. Choose between: same (same as method), openai, openaiz, openainp, maryland, marylandz')
    parser.add_argument('--seeding', type=str, default='hash',
                        help='seeding method for rng key generation as introduced in https://github.com/jwkirchenbauer/lm-watermarking')
    parser.add_argument('--ngram', type=int, default=4,
                        help='watermark context width for rng key generation')
    parser.add_argument('--gamma', type=float, default=0.25,
                        help='gamma for maryland: proportion of greenlist tokens')
    parser.add_argument('--delta', type=float, default=4.0,
                        help='delta for maryland: bias to add to greenlist tokens')
    parser.add_argument('--hash_key', type=int, default=35317,
                        help='hash key for rng key generation')
    # Implicit string concatenation instead of a backslash continuation inside
    # the literal, which used to embed the source indentation in the text.
    parser.add_argument('--scoring_method', type=str, default='none',
                        help='method for scoring. choose between: '
                             'none (score every tokens), '
                             'v1 (score token when wm context is unique), '
                             'v2 (score token when {wm context + token} is unique)')

    # multibit
    parser.add_argument('--payload', type=int, default=0, help='message')
    parser.add_argument('--payload_max', type=int, default=0,
                        help='maximal message, must be smaller than the vocab size at the moment')

    # expe parameters
    parser.add_argument('--nsamples', type=int, default=None,
                        help='number of samples to generate, if None, take all prompts')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--do_eval', type=utils.bool_inst, default=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--output_dir', type=str, default='output')
    parser.add_argument('--split', type=int, default=None,
                        help='split the prompts in nsplits chunks and chooses the split-th chunk. '
                             'Allows to run in parallel. '
                             'If None, treat prompts as a whole')
    parser.add_argument('--nsplits', type=int, default=None,
                        help='number of splits to do. If None, treat prompts as a whole')
    parser.add_argument('--nanchor', type=int, default=1,
                        help='number of anchor bits')
    parser.add_argument('--nchecksum', type=int, default=1,
                        help='number of checksum')

    # distributed parameters
    parser.add_argument('--ngpus', type=int, default=None)

    return parser


def format_prompts(prompts: List[Dict], prompt_type: str) -> List[str]:
    """Render instruction examples into prompt strings using a named template.

    Args:
        prompts: example dicts. ``instruction`` is always substituted. When an
            example has a non-empty ``input``, the "with input" template is used
            and ``input`` is substituted as well.
        prompt_type: template family, either ``'alpaca'`` or ``'guanaco'``.

    Returns:
        One formatted prompt string per example, in the same order.

    Raises:
        ValueError: if ``prompt_type`` is not a supported template. Previously
            this surfaced as an ``UnboundLocalError`` on ``PROMPT_DICT``.
    """
    if prompt_type == 'alpaca':
        # NOTE(review): the no-input template still says "paired with an input";
        # kept verbatim because changing it would alter generations — confirm intent.
        PROMPT_DICT = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context.\nWrite a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context.\nWrite a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
            ),
        }
    elif prompt_type == 'guanaco':
        PROMPT_DICT = {
            "prompt_input": (
                "A chat between a curious human and an artificial intelligence assistant.\nThe assistant gives helpful, detailed, and polite answers to the user's questions.\n\n### Human: {instruction}\n\n### Input:\n{input}\n\n### Assistant:"
            ),
            "prompt_no_input": (
                "A chat between a curious human and an artificial intelligence assistant.\nThe assistant gives helpful, detailed, and polite answers to the user's questions.\n\n### Human: {instruction}\n\n### Assistant:"
            )
        }
    else:
        raise ValueError(
            f"Unsupported prompt_type {prompt_type!r}; choose between: alpaca, guanaco"
        )
    prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
    # A missing or empty "input" selects the template without an Input section.
    return [
        prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
        for example in prompts
    ]

def load_prompts(json_path: str, prompt_type: str, nsamples: int=None) -> List[str]:
    """Load instruction examples from a JSON file and format them as prompts.

    Args:
        json_path: path to a JSON file containing a list of example dicts. Each
            dict is passed to ``format_prompts``, which reads its ``instruction``
            and optional ``input`` keys.
        prompt_type: template name forwarded to ``format_prompts``.
        nsamples: keep only the first ``nsamples`` examples; ``None`` keeps all.

    Returns:
        The formatted prompt strings.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        prompts = json.load(f)
    # No content filtering is applied; examples are only truncated to nsamples.
    new_prompts = prompts[:nsamples]
    print(f"Filtered {len(new_prompts)} prompts from {len(prompts)}")
    return format_prompts(new_prompts, prompt_type)

def load_results(json_path: str, nsamples: int=None, result_key: str='result') -> List[str]:
    """Extract one field from every record of a ``.json`` array or a JSONL file.

    Args:
        json_path: a path ending in ``.json`` is parsed as a single JSON array.
            Any other path is parsed as JSON Lines (one object per line), and
            blank lines are skipped.
        nsamples: keep only the first ``nsamples`` records; ``None`` keeps all.
        result_key: key to read from each kept record.

    Returns:
        The stored values in file order. The values are returned as stored, so
        they are not necessarily strings.

    Raises:
        KeyError: if a kept record lacks ``result_key``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        if json_path.endswith('.json'):
            records = json.load(f)
        else:
            # JSONL: tolerate blank/whitespace-only lines instead of failing
            # with JSONDecodeError.
            records = [json.loads(line) for line in f if line.strip()]
    return [record[result_key] for record in records[:nsamples]]


def main(args):
    if usermode:
        args.output_dir = os.path.join(args.output_dir, usermode_str)
        if "method" in usermode:
            args.method = usermode["method"]
            args.method_detect = usermode["method"]
        if "chars" in usermode:
            usermode["chars"] = str(usermode["chars"])

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # build model
    if args.model_name == "llama-7b":
        model_name = "huggyllama/llama-7b"
        adapters_name = None
    if args.model_name == "guanaco-7b":
        model_name = "huggyllama/llama-7b"
        adapters_name = 'timdettmers/guanaco-7b'
    elif args.model_name == "guanaco-13b":
        model_name = "huggyllama/llama-13b"
        adapters_name = 'timdettmers/guanaco-13b'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    args.ngpus = torch.cuda.device_count() if args.ngpus is None else args.ngpus
    if "detect" not in extramode and "eval" not in extramode:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            #max_memory={i: '32000MB' for i in range(args.ngpus)},
            #offload_folder="offload",
        )
        if adapters_name is not None:
            model = PeftModel.from_pretrained(model, adapters_name)
        model = model.eval()
        for param in model.parameters():
            param.requires_grad = False
        if args.ngpus > 0:
            model.to('cuda')
        print(f"Using {args.ngpus}/{torch.cuda.device_count()} GPUs - {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated per GPU")

        if "gkey" in usermode:
            args.hash_key = usermode["gkey"]
        if "wowm" in usermode or "oriwowm" in usermode:
            args.temperature = -1

        # build watermark generator
        if args.method == "none" or args.method is None:
            generator = WmGenerator(model, tokenizer)
        elif args.method == "openai":
            generator = OpenaiGenerator(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, payload=args.payload, args = args)
        elif args.method == "maryland":
            generator = MarylandGenerator(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, payload=args.payload, gamma=args.gamma, delta=args.delta)
        else:
            raise NotImplementedError("method {} not implemented".format(args.method))

    # load prompts
    if args.prompt is not None:
        prompts = args.prompt
        prompts = [{"instruction": prompt} for prompt in prompts]
    else:
        prompts = load_prompts(json_path=args.prompt_path, prompt_type=args.prompt_type, nsamples=args.nsamples)

    # do splits
    if args.split is not None:
        nprompts = len(prompts)
        left = nprompts * args.split // args.nsplits 
        right = nprompts * (args.split + 1) // args.nsplits if (args.split != args.nsplits - 1) else nprompts
        prompts = prompts[left:right]
        print(f"Creating prompts from {left} to {right}")
    
    # (re)start experiment
    os.makedirs(args.output_dir, exist_ok=True)
    start_point = 0 # if resuming, start from the last line of the file
    if False:
        if os.path.exists(os.path.join(args.output_dir, f"results.jsonl")):
            with open(os.path.join(args.output_dir, f"results.jsonl"), "r") as f:
                for _ in f:
                    start_point += 1
    print(f"Starting from {start_point}")

    # generate
    all_times = []
    if "detect" not in extramode and "eval" not in extramode:
        if "rkey" in usermode or "nchars" in usermode or "brick3" in usermode:
            exists = set()
        with open(os.path.join(args.output_dir, f"results.jsonl"), "w") as f:
            for ii in tqdm.tqdm(range(start_point, len(prompts), args.batch_size)):
                if "useori" in usermode:
                    continue
                if "rkey" in usermode:
                    if "seedbsz" in usermode:
                        generator.key01 = []
                        while len(generator.key01) < args.batch_size:
                            if len(generator.key01) % usermode["seedbsz"] == 0:
                                while True:
                                    nk = generator.add_anchorcheck(''.join(random.choice('01') for _ in range(int(generator.usermode["rkey"]))))
                                    if nk not in exists:
                                        exists.add(nk)
                                        break
                                generator.key01.append(nk)
                            else:
                                generator.key01.append(generator.key01[-1])
                    else:
                        while generator.key01 in exists:
                            generator.key01 = generator.add_anchorcheck(''.join(random.choice('01') for _ in range(int(generator.usermode["rkey"]))))
                        exists.add(generator.key01)
                if "nchars" in usermode or "brick3" in usermode:
                    alphabets = string.digits + string.ascii_lowercase + string.ascii_uppercase
                    if "seedbsz" in usermode:
                        generator.key01 = []
                        while len(generator.key01) < args.batch_size:
                            if len(generator.key01) % usermode["seedbsz"] == 0:
                                while True:
                                    sz = usermode["nchars"] if "nchars" in usermode else usermode["brick3"]
                                    nk = ''.join(random.choice(alphabets) for _ in range(int(sz)))
                                    if nk not in exists:
                                        exists.add(nk)
                                        break
                                generator.key01.append(nk)
                            else:
                                generator.key01.append(generator.key01[-1])
                    else:
                        while generator.key01 in exists:
                            generator.key01 = ''.join(random.choice(alphabets) for _ in range(int(usermode["nchars"])))
                        exists.add(generator.key01)
                if "mix" in usermode:
                    generator.key01 = []
                    while len(generator.key01) < args.batch_size:
                        if random.random() > usermode["mix"]:
                            generator.key01.append(None)
                        else:
                            nk = random.randrange(usermode.get("maxkey", 1000))
                            if "alt2" in usermode:
                                nk = [1, nk]
                            if "alt" in usermode:
                                nk = list(range(1, args.nanchor + 1)) + [nk]
                            generator.key01.append(nk)
                # generate chunk
                time0 = time.time()
                chunk_size = min(args.batch_size, len(prompts) - ii)
                results = generator.generate(
                    prompts[ii:ii+chunk_size], 
                    max_gen_len=args.max_gen_len, 
                    temperature=args.temperature, 
                    top_p=args.top_p
                )
                time1 = time.time()
                # time chunk
                speed = chunk_size / (time1 - time0)
                eta = (len(prompts) - ii) / speed
                eta = time.strftime("%Hh%Mm%Ss", time.gmtime(eta)) 
                all_times.append(time1 - time0)
                print(f"Generated {ii:5d} - {ii+chunk_size:5d} - Speed {speed:.2f} prompts/s - ETA {eta}")
                # log
                #for prompt, result in zip(prompts[ii:ii+chunk_size], results):
                for bidx, (prompt, result) in enumerate(zip(prompts[ii:ii+chunk_size], results)):
                    r = {
                        "prompt": prompt, 
                        "result": result[len(prompt):],
                        "speed": speed,
                        "eta": eta}
                    if "mkey" in usermode or "rkey" in usermode or "nchars" in usermode or "brick3" in usermode or "mix" in usermode:
                        if type(generator.key01) is list:
                            r.update({"key01" : generator.key01[bidx]})
                        else:
                            r.update({"key01" : generator.key01})
                    f.write(json.dumps(r) + "\n")
                    f.flush()
        print(f"Average time per prompt: {np.sum(all_times) / (len(prompts) - start_point) :.2f}")

    if "eval" not in extramode:
        if args.method_detect == 'same':
            args.method_detect = args.method
        if (not args.do_eval) or (args.method_detect not in ["openai", "maryland", "marylandz", "openaiz", "openainp", None]):
            return
        
        if "dkey" in usermode:
            args.hash_key = usermode["dkey"]
        if "dkey" in extramode:
            args.hash_key = extramode["dkey"]
        print(f"Detect with {args.method_detect}")

        # build watermark detector
        if args.method_detect == "openai" or args.method_detect is None:
            detector = OpenaiDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
        elif args.method_detect == "openaiz":
            detector = OpenaiDetectorZ(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
        elif args.method_detect == "openainp":
            detector = OpenaiNeymanPearsonDetector(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
        elif args.method_detect == "maryland":
            detector = MarylandDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)
        elif args.method_detect == "marylandz":
            detector = MarylandDetectorZ(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)
        detector.args = args

        # build sbert model
        sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
        cossim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        results_orig = load_results(json_path=args.prompt_path, nsamples=args.nsamples, result_key="output")
        if args.split is not None:
            results_orig = results_orig[left:right]

        #=========================== evaluate ===============================
        if "useori" in usermode:
            results = results_orig
        else:
            results = load_results(json_path=os.path.join(args.output_dir, f"results.jsonl"), nsamples=args.nsamples, result_key="result")
        log_stats = []
        if "mkey" in usermode or "rkey" in usermode or "chars" in usermode or "nchars" in usermode or "brick3" in usermode:
            mscores_list = []
        if "rkey" in usermode or "nchars" in usermode or "brick3" in usermode or "mix" in usermode:
            key01s = load_results(json_path=os.path.join(args.output_dir, f"results.jsonl"), nsamples=args.nsamples, result_key="key01")
        if "wowm" in usermode:
            key01s = [0] * args.nsamples

        text_index = left if args.split is not None else 0
        if "chars" in usermode or "nchars" in usermode or "brick3" in usermode or "wowm" in usermode or "mix" in usermode:
            parallel = True
            if "dbgdetector" in usermode:
                parallel = False
            if parallel:
                import concurrent.futures
                # 初始化 mscores_list
                mscores_list = []

                # 创建独立的 tqdm 进度条
                total_tasks = len(results)
                progress_bar = tqdm.tqdm(total=total_tasks, desc="Processing texts")

                with concurrent.futures.ProcessPoolExecutor(max_workers= min(os.cpu_count(), 48) - 1) as executor:#
                    try:
                        futures = [
                            executor.submit(process_text, text, text_orig, tokenizer, args, key01, usermode) 
                            for text, text_orig, key01 in zip(results, results_orig, key01s)
                        ]
                        
                        # 使用 tqdm 包装 futures
                        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing texts"):
                                result = future.result()
                                if "chars" in usermode:
                                    mscores_list.append([usermode["chars"], result[0]])
                                else:
                                    mscores_list.append([result[1], result[0]])
                    except Exception as exc:
                        print(f'Text processing generated an exception: {exc}')
                        traceback.print_exc()

                # 关闭进度条
                progress_bar.close()
            else:
                mscores_list = []
                for text, text_orig, key01 in tqdm.tqdm(zip(results, results_orig, key01s)):
                    result = process_text(text, text_orig, tokenizer, args, key01, usermode)
                    mscores_list.append([result[1], result[0]])
                

            results = results_orig = []# avoid origional runing
        with open(os.path.join(args.output_dir, 'scores.jsonl'), 'w') as f:
            for text, text_orig in tqdm.tqdm(zip(results, results_orig)):
                # compute watermark score
                if args.method_detect == "openainp":
                    scores_no_aggreg, probs = detector.get_scores_by_t([text], scoring_method=args.scoring_method, payload_max=args.payload_max)
                    scores = detector.aggregate_scores(scores_no_aggreg) # p 1
                    pvalues = detector.get_pvalues(scores_no_aggreg, probs)
                else:
                    scores_no_aggreg = detector.get_scores_by_t([text], scoring_method=args.scoring_method, payload_max=args.payload_max)
                    if "mkey" in usermode or "rkey" in usermode:
                        scores_no_aggreg, seeds = scores_no_aggreg
                        seed_scores = group_scores_by_seeds(seeds[0], scores_no_aggreg[0])
                        mscores = {seed : detector.aggregate_scores([seed_scores[seed]]) for seed in seed_scores}
                        mpvalues = {seed : detector.get_pvalues([seed_scores[seed]]) for seed in seed_scores}
                    if "chars" in usermode or "nchars" in usermode:
                        #"""
                        mscores = {}
                        mpvalues = {}
                        for seed in range(256):
                            scores_no_aggreg = detector.get_scores_by_t([text], scoring_method=args.scoring_method, payload_max=args.payload_max, chars_seed = seed)
                            scores_no_aggreg, seeds = scores_no_aggreg
                            seed_scores = group_scores_by_seeds(seeds[0], scores_no_aggreg[0])
                            mscores[seed] = {pos : detector.aggregate_scores([seed_scores[pos]]) for pos in seed_scores}
                            mpvalues[seed] = {pos : detector.get_pvalues([seed_scores[pos]]) for pos in seed_scores}
                    scores = detector.aggregate_scores(scores_no_aggreg) # p 1
                    pvalues = detector.get_pvalues(scores_no_aggreg) 
                if args.payload_max:
                    if "mkey" in usermode or "rkey" in usermode:
                        for seed in mscores:
                            seed_payloads = np.argmin(mpvalues[seed], axis=1).tolist()
                            mscores[seed] = [float(s[payload]) for s,payload in zip(mscores[seed],seed_payloads)]
                    if "chars" in usermode or "nchars" in usermode:
                        for seed in mscores:
                            for pos in mscores[seed]:
                                seed_payloads = np.argmin(mpvalues[seed][pos], axis=1).tolist()
                                mscores[seed][pos] = [float(s[payload]) for s,payload in zip(mscores[seed][pos],seed_payloads)]
                    if "mkey" in usermode:
                        mscores_list.append(mscores)
                    if "rkey" in usermode or "nchars" in usermode:
                        mscores_list.append([key01s[text_index], mscores])
                    if "chars" in usermode:
                        mscores_list.append([usermode["chars"], mscores])
                    # decode payload and adjust pvalues
                    if len(pvalues[0]) == 0:
                        continue
                    payloads = np.argmin(pvalues, axis=1).tolist()
                    all_pvalues = [pvalues[0].tolist()]
                    pvalues = pvalues[:,payloads][0].tolist() # in fact pvalue is of size 1, but the format could be adapted to take multiple text at the same time
                    scores = [float(s[payload]) for s,payload in zip(scores,payloads)]
                    # adjust pvalue to take into account the number of tests (2**payload_max)
                    # use exact formula for high values and (more stable) upper bound for lower values
                    M = args.payload_max+1
                    pvalues = [(1 - (1 - pvalue)**M) if pvalue > min(1 / M, 1e-5) else M * pvalue for pvalue in pvalues]
                else:
                    payloads = [ 0 ] * len(pvalues)
                    pvalues = pvalues[:,0].tolist()
                    all_pvalues = pvalues
                    scores = [float(s[0]) for s in scores]
                num_tokens = [len(score_no_aggreg) for score_no_aggreg in scores_no_aggreg]
                # compute sbert score
                xs = sbert_model.encode([text, text_orig], convert_to_tensor=True)
                score_sbert = cossim(xs[0], xs[1]).item()
                # log stats and write
                log_stat = {
                    'text_index': text_index,
                    'num_token': num_tokens[0],
                    'score': scores[0],
                    'pvalue': pvalues[0], 
                    'all_pvalues': all_pvalues[0],
                    'score_sbert': score_sbert,
                    'payload': payloads[0],
                }
                log_stats.append(log_stat)
                f.write(json.dumps(log_stat)+'\n')
                text_index += 1
                """
                os.system(f"echo {str(mscores)} >> output/exps/$usermode/mscores.jsonl")#"""
            json.dump(mscores_list, open(os.path.join(args.output_dir, 'mscores_list.json'), "w"))

        if "mkey" in usermode:
            sum_dict = {}
            count_dict = {}

            # Iterate over each dictionary in the list
            for d in mscores_list:
                for key, value in d.items():
                    if key in sum_dict:
                        sum_dict[key] += sum(value)
                        count_dict[key] += len(value)
                    else:
                        sum_dict[key] = sum(value)
                        count_dict[key] = len(value)

            mean_dict = {key: sum_dict[key] / count_dict[key] for key in sorted(sum_dict.keys())}

            possible_list = generate_possible_sequences(mean_dict, top_n = -1)
            predict = possible_list[0][0]
            res = {"bitscore" : mean_dict, "predict" : predict}
            print(res)


        if "rkey" in usermode:
            sum_dict = {}
            count_dict = {}

            # Iterate over each dictionary in the list
            for key01, d in mscores_list:
                if key01 not in sum_dict:
                    sum_dict[key01], count_dict[key01] = {}, {}
                for key, value in d.items():
                    if key in sum_dict[key01]:
                        sum_dict[key01][key] += sum(value)
                        count_dict[key01][key] += len(value)
                    else:
                        sum_dict[key01][key] = sum(value)
                        count_dict[key01][key] = len(value)

            # Calculate the mean for each key
            predicts = []
            stats = {"cnt" : 0, "match" : 0, "bitcnt" : 0,  "bitmatch" : 0}
            for key01 in sum_dict:
                mean_dict = {key: sum_dict[key01][key] / count_dict[key01][key] for key in sorted(sum_dict[key01].keys())}

                # Print the result
                #thres = usermode["thres"] if "thres" in usermode else 1.7
                #predict = "".join([str(int(mean_dict[key] > thres)) for key in mean_dict])
                anchorchecklen = args.nchecksum + args.nanchor
                possible_list = generate_possible_sequences(mean_dict, top_n = -1)
                predict = ([pl for pl in possible_list if pl[:anchorchecklen] == generator.anchor + calculate_checksum(pl[anchorchecklen:])] + [possible_list[0]])[0][0]
                res = {"key01" : key01, "bitscore" : mean_dict, "predict" : predict}
                predicts.append(res)
                stats["cnt"] += 1
                stats["match"] += int(predict[anchorchecklen:] == key01[anchorchecklen:])
                stats["bitcnt"] += len(predict) - anchorchecklen
                stats["bitmatch"] += sum(c1 == c2 for c1, c2 in zip(key01[anchorchecklen:], predict[anchorchecklen:]))
            stats["accuracy"] = stats["match"] / stats["cnt"]
            stats["bitaccuracy"] = stats["bitmatch"] / stats["bitcnt"]
            json.dump(predicts, open(os.path.join(args.output_dir, 'predicts.json'), "w"))
            json.dump(stats, open(os.path.join(args.output_dir, 'predicts_scores.json'), "w"))
            print(stats)

        if "chars" in usermode or "nchars" in usermode or "brick3" in usermode:
            sum_dict = {}
            count_dict = {}
            # Iterate over each dictionary in the list
            results = {}
            mean_dict = {}
            for msg, d in mscores_list:
                if msg not in sum_dict:
                    sum_dict[msg], count_dict[msg], mean_dict[msg] = {}, {}, {}
                
                for seed in d:
                    for pos in d[seed]:
                        if pos not in sum_dict[msg]:
                            sum_dict[msg][pos], count_dict[msg][pos], mean_dict[msg][pos] = {}, {}, {}
                        #sum_dict[msg][pos][seed] = sum_dict[msg][pos].get(seed, 0) + sum(d[seed][pos])
                        sum_dict[msg][pos][seed] = sum_dict[msg][pos].get(seed, []) + d[seed][pos]
                        count_dict[msg][pos][seed] = count_dict[msg][pos].get(seed, 0) + 1

                for pos in sum_dict[msg]:
                    for seed in sum_dict[msg][pos]:
                        #mean_dict[msg][pos][seed] = sum_dict[msg][pos][seed] / count_dict[msg][pos][seed]
                        mean_dict[msg][pos][seed] = np.median(sum_dict[msg][pos][seed])

                res = []
                if "brick3" in usermode:
                    maxseed = max(mean_dict[msg][0], key=mean_dict[msg][0].get)
                    uid = maxseed
                    while uid > 0:
                        res.append(chr(uid % 256))
                        uid = uid // 256
                else:
                    for pos in sorted(mean_dict[msg].keys()):
                        maxseed = max(mean_dict[msg][pos], key=mean_dict[msg][pos].get)
                        res.append(chr(maxseed))
                results[msg] = "".join(res)
            print(results)
            stats = {"cnt" : 0, "match" : 0, "bitcnt" : 0,  "bitmatch" : 0}
            for key01 in results:
                predict = results[key01]
                stats["cnt"] += 1
                stats["match"] += int(predict == key01)
                stats["bitcnt"] += len(predict)
                stats["bitmatch"] += sum(c1 == c2 for c1, c2 in zip(key01, predict))
            stats["accuracy"] = stats["match"] / stats["cnt"]
            stats["bitaccuracy"] = stats["bitmatch"] / stats["bitcnt"]
            print(stats)
            exit(0)
        pickle.dump(mscores_list, open(os.path.join(args.output_dir, 'predicts_scores.pkl'), "wb"))
    else:
        mscores_list = pickle.load(open(os.path.join(args.output_dir, 'predicts_scores.pkl'), "rb"))
    if "filtershort" in extramode:
        results = load_results(json_path=os.path.join(args.output_dir, f"results.jsonl"), nsamples=args.nsamples, result_key="result")
        mscores_list = [m for m ,r in zip(mscores_list, results) if len(r) > 100]



    if "wowm" in usermode:
        res = []
        for _, m in mscores_list:
            res.append([m[i][0][0] for i in m])
        json.dump(res, open(os.path.join(args.output_dir, 'scores.json'), "w"))
        exit(0)

    if "mix" in usermode:
        filter_mscores_list = [m for m in mscores_list if len(m[1][0]) > 0]
        def split_dev(lst, dev_ratio=0.5, split_method="Sampling"):
            """Split ``lst`` into a ``(dev, test)`` pair.

            Args:
                lst: indexable sequence of samples to split.
                dev_ratio: fraction of samples assigned to the dev split; the
                    dev size is ``int(len(lst) * dev_ratio)``.
                split_method: ``"Sampling"`` (default) shuffles the indices with
                    a fixed seed, so every call returns the same partition;
                    ``"FirstDev"`` takes the leading prefix as the dev split.

            Returns:
                ``(dev, test)``. When ``"tunetest"`` is in ``usermode`` the test
                split is returned twice (debug mode: tune on the test set).

            Raises:
                ValueError: if ``split_method`` is not recognised.
            """
            dev_size = int(len(lst) * dev_ratio)
            if split_method == "FirstDev":
                dev = lst[:dev_size]
                test = lst[dev_size:]
            elif split_method == "Sampling":
                # Use a private RNG seeded with 0: it yields the same permutation
                # as ``random.seed(0); random.shuffle(...)`` without resetting
                # the global ``random`` state for the rest of the program.
                indices = list(range(len(lst)))
                random.Random(0).shuffle(indices)
                dev = [lst[i] for i in indices[:dev_size]]
                test = [lst[i] for i in indices[dev_size:]]
            else:
                raise ValueError(f"Unknown split_method: {split_method!r}")
            if "tunetest" in usermode:
                print("Debug tuning on the test set!")
                return test, test
            return dev, test
        def get_dev_thres(score_func):
            """Grid-search the decision threshold(s) for ``score_func`` on the dev split.

            ``score_func(thres, samples, usermode)`` must return
            ``(keys, predict_keys, scores)``. A candidate is rated by the
            fraction of samples whose predicted key equals the gold key. The
            candidate shape depends on ``usermode``:

            * ``ckmax``: 3-tuples ``(thres, thresl, thresu)``, searched via
              ``parallel_tune_dev``;
            * ``ckadd``: pairs ``(thres, w)`` with ``w`` in 0.2..2.2, step 0.1;
            * ``ckand`` / ``ckcmp``: pairs ``(thres, thres1)``;
            * otherwise: a single scalar threshold.

            Each threshold coordinate ranges over 0.00..8.00 in steps of 0.02.

            Returns:
                ``(thres, test_list)``: the selected threshold and the held-out
                test split of ``filter_mscores_list``.
            """
            def _accuracy(keys, predict_keys):
                # Exact-match accuracy; max(1, ...) guards an empty dev split.
                return sum(x == y for x, y in zip(keys, predict_keys)) / max(1, len(keys))

            threses = [round(i * 0.02, 2) for i in range(0, 401)]
            # split_dev is deterministic, so one call yields the same dev/test
            # partition that separate calls for each half would.
            dev_list, test_list = split_dev(filter_mscores_list)
            accus = {}
            if "ckmax" in usermode:
                parallel = True
                if parallel:
                    accus = parallel_tune_dev(threses, dev_list, usermode, score_func)
                else:
                    for thres in tqdm.tqdm(threses):
                        for thresl in threses:
                            for thresu in threses:
                                cand = (thres, thresl, thresu)
                                keys, predict_keys, _ = score_func(cand, dev_list, usermode)
                                accus[cand] = _accuracy(keys, predict_keys)
            elif "ckadd" in usermode:
                # Round after adding the offset so the weights are clean
                # decimals (0.3, 0.6, ...) rather than 0.30000000000000004.
                weights = [round(0.2 + i * 0.1, 2) for i in range(0, 21)]
                for thres in threses:
                    for w in weights:
                        keys, predict_keys, _ = score_func((thres, w), dev_list, usermode)
                        accus[(thres, w)] = _accuracy(keys, predict_keys)
            elif "ckand" in usermode or "ckcmp" in usermode:
                for thres in threses:
                    for thres1 in threses:
                        keys, predict_keys, _ = score_func((thres, thres1), dev_list, usermode)
                        accus[(thres, thres1)] = _accuracy(keys, predict_keys)
            else:
                for thres in threses:
                    keys, predict_keys, _ = score_func(thres, dev_list, usermode)
                    accus[thres] = _accuracy(keys, predict_keys)

            thres = max(accus, key=accus.get)
            if "ckmax" in usermode:
                # Among the grid points tied for the best accuracy: take the
                # middle (upper-median) first coordinate, then the smallest
                # second and the largest third coordinate consistent with it.
                maxval = max(accus.values())
                mthres = [k for k, v in accus.items() if v == maxval]
                arr0 = [s[0] for s in mthres]
                t0 = arr0[np.argsort(arr0)[len(arr0) // 2]]
                t1 = min(s[1] for s in mthres if s[0] == t0)
                t2 = max(s[2] for s in mthres if (s[0], s[1]) == (t0, t1))
                thres = (t0, t1, t2)
            elif "ckadd" in usermode:
                # Middle first coordinate among the tied best points, then the
                # median weight among the points that share it.
                maxval = max(accus.values())
                mthres = [k for k, v in accus.items() if v == maxval]
                arr0 = [s[0] for s in mthres]
                t0 = arr0[np.argsort(arr0)[len(arr0) // 2]]
                t1 = float(np.median([s[1] for s in mthres if s[0] == t0]))
                thres = (t0, t1)
            return thres, test_list
        if "alt2" in usermode:
            filter_mscores_list = [{
                "gold" : m[0][1] if m[0] is not None else None, 
                "worn_score" : m[1][0][0][0] if len(m[1][0]) == 2 else -1, 
                "max_id_score" : max([m[1][i][1][0] if len(m[1][i]) == 2 else -1 for i in range(len(m[1]))]), 
                "max_id" : int(np.argmax([m[1][i][1][0] if len(m[1][i]) == 2 else -1 for i in range(len(m[1]))])),
                "id_mean_score" : float(np.mean([m[1][i][1][0] if len(m[1][i]) == 2 else -1 for i in range(len(m[1]))])),
                "id_second_score" : sorted([m[1][i][1][0] if len(m[1][i]) == 2 else -1 for i in range(len(m[1]))], reverse=True)[1]}
            for m in filter_mscores_list]
            thres = usermode.get("thres", 1.4)
            if "devthres" in usermode:
                thres, new_filter_mscores_list  = get_dev_thres(calc_alt2_scores)
                print("Thres:", thres)
                filter_mscores_list = new_filter_mscores_list
            keys, predict_keys, scores = calc_alt2_scores(thres, filter_mscores_list, usermode)

        elif "alt" in usermode:
            filter_mscores_list = [m for m in mscores_list if len(m[1][0]) > args.nanchor]
            keys, predict_keys, anchors = [], [], []
            for m in filter_mscores_list:
                anchors.append([m[1][0][i][0] for i in range(args.nanchor)])
                keys.append(m[0][-1] if m[0] is not None else None)
                s = [m[1][i][args.nanchor][0] for i in range(len(m[1]))]
                predict_keys.append(int(np.argmax(s)))
                #if m[1][0][0][0] < usermode.get("thres", 1.4):
                #if m[1][0][0][0] - m[1][0][1][0] < 0.12:
                if np.mean(anchors[-1]) < usermode.get("thres", 1.4):
                    predict_keys[-1] = None
            anchort = [np.mean([a[i] for i in range(0, len(a), 2)]) for a in anchors]
            anchorf = [np.mean([a[i] for i in range(1, len(a), 2)]) for a in anchors]

            isNone = [t - f < 0.1 for t, f in zip(anchort, anchorf)]


                #if len(m[1][0]) != args.nanchor + 1 or m[1][0][0][0] < usermode.get("thres", 1.4):
                #    pass
        else:
            def calc_base_scores(thres, filter_mscores_list, usermode):
                """Predict one key per sample from its diagonal per-key scores.

                For sample ``m``, candidate key ``i`` scores ``m[1][i][i][0]``
                and the gold key is ``m[0]``. The prediction is the argmax
                candidate, or ``None`` when the decision rule rejects it:

                * ``msndbase``: reject unless ``max - second_best > thres``;
                * ``mmbase``: reject unless ``max - mean > thres``;
                * otherwise: reject unless ``max > thres``.

                If both ``msndbase`` and ``mmbase`` are set, ``msndbase`` is used.

                Returns:
                    ``(keys, predict_keys, scores)``, where ``scores`` holds each
                    sample's maximum candidate score.
                """
                sample_key_score = [[m[1][i][i][0] for i in range(len(m[1]))] for m in filter_mscores_list]
                keys = [m[0] for m in filter_mscores_list]
                # A single if/elif/else chain: previously a second bare `if`
                # let the plain-threshold branch overwrite the mmbase result.
                if "msndbase" in usermode:
                    # Margin over the runner-up candidate score.
                    snd_score = [sorted(s, reverse=True)[1] for s in sample_key_score]
                    predict_keys = [int(np.argmax(s)) if np.max(s) - b > thres else None
                                    for b, s in zip(snd_score, sample_key_score)]
                elif "mmbase" in usermode:
                    # Margin over the mean candidate score.
                    mean_score = [np.mean(s) for s in sample_key_score]
                    predict_keys = [int(np.argmax(s)) if np.max(s) - b > thres else None
                                    for b, s in zip(mean_score, sample_key_score)]
                else:
                    predict_keys = [int(np.argmax(s)) if np.max(s) > thres else None
                                    for s in sample_key_score]
                scores = [float(np.max(s)) for s in sample_key_score]
                return keys, predict_keys, scores
            thres = usermode.get("thres", 1.4)
            if "devthres" in usermode:
                thres, new_filter_mscores_list  = get_dev_thres(calc_base_scores)
                filter_mscores_list = new_filter_mscores_list
            keys, predict_keys, scores = calc_base_scores(thres, filter_mscores_list, usermode)    
        accu = sum([x == y for x, y in zip(keys, predict_keys)]) / len(keys)
        a_none = [x == y for x, y in zip(keys, predict_keys) if x is None]
        accu_None = sum(a_none) / max(1, len(a_none))
        a_uid = [x == y for x, y in zip(keys, predict_keys) if x is not None]
        accu_uid = sum(a_uid) / max(1, len(a_uid))
        a_worn = [(x if x is None else 1) == (y if y is None else 1) for x, y in zip(keys, predict_keys)]
        accu_worn = sum(a_worn) / max(1, len(a_worn))
        a_fuid = [x == y for x, y in zip(keys, predict_keys)]
        accu_fuid = sum(a_fuid) / max(1, len(a_fuid))
        fp_worn = [y is not None for x, y in zip(keys, predict_keys) if x is None]
        fpr_worn = sum(fp_worn) / max(1, len(fp_worn))
        res = {
            "keys" : keys,
            "predict_keys" : predict_keys,
            "scores" : scores,
            "accu" : accu,
            "accu_none" : accu_None,
            "accu_uid" : accu_uid,
            "accu_worn" : accu_worn,
            "accu_fuid" : accu_fuid,
            "fpr_worn": fpr_worn,
            "thres" : thres
        }
        print("accu: ", accu, "accu_worn: ", accu_worn)
        if "auc" in usermode:
            res.update({"auc" : auc})
        json.dump(res, open(os.path.join(args.output_dir, 'predicts_scores.json'), "w"))
        exit(0)
        

    df = pd.DataFrame(log_stats)
    df['log10_pvalue'] = np.log10(df['pvalue'])
    print(f">>> Scores: \n{df.describe(percentiles=[])}")
    print(f"Saved scores to {os.path.join(args.output_dir, 'scores.csv')}")


if __name__ == "__main__":
    # Parse the command line, hand the parsed args to hyperparse.reset_hyper
    # for each mode dict (usermode first, then extramode), and run.
    args = get_args_parser().parse_args()
    for _hyper_mode in (usermode, extramode):
        hyperparse.reset_hyper(_hyper_mode, args)
    main(args)
